from gym.spaces import Discrete, Box, MultiDiscrete, Dict
import numpy as np
import gym
import random
import re
import math
from utils import get_all_wrappers
from collections import OrderedDict

class CostPerturbationWrapper(gym.Wrapper):
    """Rotate the wrapped environment's marginal cost ``c_i`` from episode to episode.

    The candidate costs are ``1.0`` followed by the midpoints of the first three
    consecutive intervals of ``env.action_price_space``. Every ``reset`` installs
    the next candidate (round-robin) and then resets the wrapped environment.
    """

    def __init__(
            self,
            env,
    ):
        super(CostPerturbationWrapper, self).__init__(env)
        price_grid = self.env.action_price_space
        # Midpoints between grid points 0-1, 1-2 and 2-3.
        midpoints = [(price_grid[idx] + price_grid[idx + 1]) / 2 for idx in range(3)]
        self.costs = [1.0] + midpoints
        self.episode_counter = 0

    def reset(self):
        """Install the next marginal cost, then reset the wrapped environment.

        Returns:
            Whatever ``self.env.reset()`` returns.
        """
        slot = self.episode_counter % len(self.costs)
        # NOTE(review): the attribute is assigned on the directly wrapped object;
        # if that object is itself a wrapper, confirm the base environment
        # actually sees the new cost.
        self.env.c_i = self.costs[slot]
        self.episode_counter += 1
        return self.env.reset()



class BertrandCompetitionDiscreteEnv(gym.Env):
    """Multi-agent price competition on a discrete price grid with a buy-box.

    Each agent picks an index into ``action_price_space``. Demand follows the
    logit-style formula implemented in ``demand``; a "buy-box" rule selected by
    ``dp_type`` decides which agents are displayed, and ``gamma`` blends the
    buy-box demand with the unrestricted demand.

    Observations, rewards and actions are dictionaries keyed by agent name
    (``'agent_0'``, ``'agent_1'``, ...). ``done`` is a dict with the single
    key ``'__all__'``.
    """
    metadata = {'render.modes': ['human']}

    def __init__(
            self,
            num_agents=2,
            c_i=1,
            a=2,
            gamma=1,
            dp_type='pdp',
            a_0=0,
            mu=0.25,
            m=15,
            adv=0.3,
            k=1,
            max_steps=100000,
            grid_upper_bound=2.1,
            is_evaluation_env = False,
    ):
        """Store the market parameters, build the spaces and print a summary.

        Args:
            num_agents: number of pricing agents.
            c_i: marginal cost shared by all agents.
            a: product quality index, replicated for every agent.
            gamma: weight of the buy-box demand/surplus versus the unrestricted one.
            dp_type: buy-box rule; ``get_bbx_idx`` handles 'no_intervene', 'pdp',
                'dpdp', 'learn_exact_price', 'learn_threshold' and 'learn_general'.
            a_0: quality index of the outside good.
            mu: horizontal differentiation index.
            m: number of discrete prices.
            adv: undercut threshold used by the 'dpdp' rule.
            k: memory length (number of past price indices per agent in the observation).
            max_steps: step count at which ``done['__all__']`` becomes True.
            grid_upper_bound: highest price on the grid (the lowest is 0.95).
            is_evaluation_env: flag stored on the instance; not used inside this class.
        """

        super(BertrandCompetitionDiscreteEnv, self).__init__()
        self.num_agents = num_agents

        # self.bb_size = bb_size
        self.is_evaluation_env = is_evaluation_env

        # Length of Memory
        self.k = k

        # Marginal Cost
        self.c_i = c_i

        # Number of Discrete Prices
        self.m = m

        # Fraction of how much mechanism implements PDP/DPDP
        self.gamma = gamma

        # Sigma is a smoothing parameter; it is stored and printed here but not
        # otherwise read inside this class.
        self.sigma = 0.01

        # ADV value for dynamic pdp
        self.adv = adv

        # Product Quality Indexes
        self.a = np.array([a] * num_agents)

        # Product Quality Index: Outside Good
        self.a_0 = a_0

        # Index of Horizontal Differentiation
        self.mu = mu

        # Buybox type
        self.dp_type = dp_type

        # MultiAgentEnv Action and Observation Space
        self.agents = ['agent_' + str(i) for i in range(num_agents)]
        self.observation_spaces = {}
        self.action_spaces = {}

        # With memory (k > 0) every agent observes k past price indices of every
        # agent; without memory the observation has one entry per agent.
        if k > 0:
            self.numeric_low = np.array([0] * (k * num_agents))
            numeric_high = np.array([m] * (k * num_agents))
            obs_space = Box(self.numeric_low, numeric_high, dtype=int)
        else:
            self.numeric_low = np.array([0] * num_agents)
            numeric_high = np.array([m] * num_agents)
            obs_space = Box(self.numeric_low, numeric_high, dtype=int)

        for agent in self.agents:
            self.observation_spaces[agent] = obs_space
            self.action_spaces[agent] = Discrete(m)

        self.action_price_space = np.linspace(0.95, grid_upper_bound, m)

        self.current_step = None
        self.max_steps = max_steps

        self.action_history = {}
        self.consumer_surplus = []
        self.demands = []

        # Seed every agent's history with one random action index.
        for agent in self.agents:
            if agent not in self.action_history:
                self.action_history[agent] = [self.action_spaces[agent].sample()]

        # History of buy-box occupants; starts with agent 0.
        self.bbx_occupant = [0]

        print("------------------------------------------")
        print("PARAMETERS IN OPERATION")
        print("number of agents: ", self.num_agents)
        print("memory length: ", k )
        print("number of prices m: ", self.m)

        print("marginal cost  c_i: ", self.c_i)
        print("product quality component a: ", self.a)
        print("a_0: ", self.a_0)
        print("mu: ", self.mu)
        print("prices: ", self.action_price_space)
        print("Gamma: ", self.gamma)
        print("Smoothing sigma: ", self.sigma)

        print("ADV: ", self.adv)
        print("max_steps: ", self.max_steps)
        print("dp_type: ", dp_type)
        print("------------------------------------------")


    def reset(self):
        """Restart the step counter and append fresh random actions to the history.

        Returns:
            dict mapping every agent name to the same observation array.
        """
        self.current_step = 0

        # Reset to random action
        random_action = np.random.randint(self.m, size=self.num_agents)

        for i in range(random_action.size):
            self.action_history[self.agents[i]].append(random_action[i])

        if self.k > 0:
            # Fill in price history with random actions
            for _ in range(self.k):
                random_action = np.random.randint(self.m, size=self.num_agents)
                for i in range(self.num_agents):
                    self.action_history[self.agents[i]].append(random_action[i])
            obs_agents = np.array([self.action_history[self.agents[i]][-self.k:] for i in range(self.num_agents)], dtype=object).flatten()

            # Rebuild the array so that its dtype is inferred from the elements
            # instead of staying ``object``.
            obs_list = list(obs_agents)
            obs_agents = np.array(obs_list)

            observation = dict(zip(self.agents, [obs_agents for _ in range(self.num_agents)]))
        else:
            observation = dict(zip(self.agents, [self.numeric_low for _ in range(self.num_agents)]))

        return observation


    def step(self, actions_dict):
        """Advance the market by one round of prices.

        Args:
            actions_dict: dict whose first ``num_agents`` values are the agents'
                price indices (in agent order). In 'learn_general' mode it must
                also contain a ``'supervisor'`` entry encoding one 0/1 digit per agent.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` and
            ``reward`` are dicts keyed by agent name, ``done`` is
            ``{'__all__': bool}`` and ``info['surplus']`` is the consumer surplus.
        """

        info = {}
        self.current_step += 1

        if self.current_step % 100000 == 0:
            print("BertrandCompetitionDiscreteEnv steps completed: ", self.current_step)

        # First, add prices to history. Only the agents' entries are converted:
        # mixing in a non-numeric supervisor entry would coerce the whole array
        # (and therefore the stored history) to strings.
        agent_actions = list(actions_dict.values())[:self.num_agents]
        actions_idx = np.array(agent_actions).flatten()
        for i in range(self.num_agents):
            self.action_history[self.agents[i]].append(actions_idx[i])

        # Then, create observation
        if self.k > 0:
            obs_agents = np.array([self.action_history[self.agents[i]][-self.k:] for i in range(self.num_agents)], dtype=object).flatten()
            observation = dict(zip(self.agents, [obs_agents for _ in range(self.num_agents)]))
        else:
            observation = dict(zip(self.agents, [self.numeric_low for _ in range(self.num_agents)]))

        self.prices_idx = [int(pr) for pr in actions_idx[:self.num_agents]]
        self.prices = self.action_price_space.take(self.prices_idx)

        if self.dp_type == 'learn_general':
            # The supervisor's choice arrives in the action dict; ``action_history``
            # only ever holds per-agent entries, so it cannot be read from there.
            supervisor_action = re.sub(r'\W+', '', str(actions_dict['supervisor']))
            self.supervisor_action = [int(char_num) for char_num in supervisor_action]
        elif self.dp_type == 'pdp' or self.dp_type == 'dpdp' or self.dp_type == 'no_intervene':
            self.supervisor_action = None
        # For the remaining dp_types the value stored by ``set_regulator_action`` is used.

        new_bbx_idx = self.get_bbx_idx(self.prices, self.supervisor_action)

        # A single occupant is stored as a scalar, several occupants as a list.
        occupants = [new_bbx_idx[k] for k in range(len(new_bbx_idx))]
        if len(occupants) == 1:
            occupants = occupants[0]

        self.bbx_occupant.append(occupants)

        # Compute demands and surpluses
        c_surplus = self.compute_surplus(self.prices, new_bbx_idx)
        self.consumer_surplus.append(c_surplus)
        info['surplus'] = c_surplus

        demands = []
        rewards = []
        for i in range(self.num_agents):
            demand_i = self.demand(self.a, self.prices, self.mu, i, new_bbx_idx)
            reward_i = (self.prices[i] - self.c_i) * demand_i
            demands.append(demand_i)
            rewards.append(reward_i)
        self.demands.append(demands)
        reward = dict(zip(self.agents, rewards))

        # Check if done
        if self.current_step == self.max_steps:
            print("reached max_steps: ", self.current_step == self.max_steps)
            done = {'__all__': True}
        else:
            done = {'__all__': False}

        return observation, reward, done, info


    def set_regulator_action(self, action):
        """Store the supervisor's action for use by the next ``step``."""
        self.supervisor_action=action


    def compute_surplus(self, prices, bbx_idx):
        """Return the gamma-weighted consumer surplus.

        Args:
            prices: one price per agent.
            bbx_idx: indices of the agents shown in the buy-box.

        Returns:
            ``gamma * surplus(buy-box agents) + (1 - gamma) * surplus(all agents)``,
            each surplus being ``mu * log(sum_i exp((a_i - p_i) / mu) + exp(a_0 / mu))``.
        """

        assert len(prices) == self.num_agents
        bbx_surp = np.sum([np.exp((self.a[i] - float(prices[i])) / self.mu) for i in bbx_idx]) + np.exp(self.a_0 / self.mu)
        bbx_surp = self.mu * np.log(bbx_surp)

        normal_surp = np.sum([np.exp((self.a[j] - float(prices[j])) / self.mu) for j in range(self.num_agents)]) + np.exp(self.a_0 / self.mu)
        normal_surp = self.mu * np.log(normal_surp)

        surp = self.gamma * bbx_surp + (1.0 - self.gamma) * normal_surp

        return surp


    def compute_simple_reward(self, prices, bbx_idx):
        """Return ``sum(exp(-p_i))`` over the agents listed in ``bbx_idx``."""

        assert len(prices) == self.num_agents
        bbx_surp = np.sum([np.exp(-float(prices[i])) for i in bbx_idx])

        return bbx_surp


    def demand(self, a, p, mu, agent_idx, bb_idx):
        ''' Demand as a function of product quality indexes, price, and mu.

        Args:
            a: array of quality indexes, one per agent.
            p: array of prices, one per agent.
            mu: horizontal differentiation index.
            agent_idx: index of the agent whose demand is computed.
            bb_idx: indices of agents in the buy-box, or None for no buy-box.

        Returns:
            The unrestricted demand when ``bb_idx`` is None; otherwise the
            gamma-weighted mix of buy-box demand (0 for agents outside the
            buy-box) and unrestricted demand.
        '''
        normal_demand = np.exp((a[agent_idx] - p[agent_idx]) / mu) / (np.sum(np.exp((a - p) / mu)) + np.exp(self.a_0 / mu))

        # Identity check: ``== None`` would be evaluated element-wise on arrays.
        if bb_idx is None:
            return normal_demand

        if agent_idx not in bb_idx:
            bbx_demand = 0
        else:
            denom = np.sum([np.exp((a[idx] - p[idx]) / mu) for idx in bb_idx]) + np.exp(self.a_0 / mu)
            bbx_demand = np.exp((a[agent_idx] - p[agent_idx]) / mu) / denom

        return self.gamma * bbx_demand + (1.0 - self.gamma) * normal_demand


    def gen_final_metrics(self):
        """Summarise the last ``self.convergence`` steps.

        Returns:
            ``(share_weighted_price, average_final_CS, self.num_converged_steps)``:
            the mean demand-weighted price (steps with zero total demand are
            dropped), the mean consumer surplus, and the stored step count.
        """
        # NOTE(review): ``self.convergence`` and ``self.num_converged_steps`` are
        # not assigned anywhere in this class; presumably they are set by the
        # caller before this method is used -- confirm.
        last_n = self.convergence
        last_surplus_n = self.consumer_surplus[-last_n:]

        weighted_prices, tot_demands = np.zeros(last_n), np.zeros(last_n)

        all_demands = np.array(self.demands[-last_n:])

        for agent in self.agents:
            last_prices = self.action_price_space.take(self.action_history[agent][-last_n:])

            # Agent names have the form 'agent_<index>'.
            _, part_2 = agent.split('_')

            agent_num = int(part_2)

            last_n_demands = np.transpose(all_demands)[agent_num]

            tot_demands += last_n_demands

            weighted_prices += last_prices * last_n_demands

        # Avoid dividing by zero total demand.
        non_zero_idx = np.flatnonzero(tot_demands)
        weighted_prices = weighted_prices[non_zero_idx]
        tot_demands = tot_demands[non_zero_idx]

        weighted_prices = weighted_prices/tot_demands
        share_weighted_price = np.mean(weighted_prices)
        average_final_CS = np.mean(last_surplus_n)

        return share_weighted_price, average_final_CS, self.num_converged_steps


    def render(self, mode='human'):
        """Rendering is not supported."""
        raise NotImplementedError


    def get_bbx_idx(self, prices, supervisor_action):
        """Return the list of agent indices displayed in the buy-box.

        Args:
            prices: one price per agent for the current step.
            supervisor_action: supervisor choice; its meaning depends on ``dp_type``
                (price index for 'learn_exact_price'/'learn_threshold', one 0/1
                flag per agent for 'learn_general', ignored otherwise).

        Returns:
            A list of agent indices, or None when ``dp_type`` matches no branch.
        """

        if self.dp_type == 'no_intervene':
            # Everybody is displayed.
            bb_size = self.num_agents
            return [int(i) for i in range(bb_size)]

        elif self.dp_type == 'pdp':
            # The (first) cheapest agent is displayed.
            bb_size = 1
            prices_np = np.sort(prices)
            assert bb_size <= self.num_agents
            return [int(np.where(prices == prices_np[i])[0][0]) for i in range(bb_size)]

        elif self.dp_type == 'dpdp':
            bb_size = 1
            prices_np = np.sort(prices)
            # [-2] is the action before the one appended in the current step.
            prev_prices_idx = [self.action_history[self.agents[i]][-2] for i in range(self.num_agents)]
            prev_prices = self.action_price_space.take(prev_prices_idx)

            bbx_occupant_idx = self.bbx_occupant[-1]
            bbx_occupant_price = prices[bbx_occupant_idx]
            bbx_occupant_prev_price = prev_prices[bbx_occupant_idx]

            occupant_price_not_rise = bbx_occupant_price <= bbx_occupant_prev_price

            # The rival index is computed as 1 - occupant, i.e. for two agents.
            non_bbx_idx = int(1 - bbx_occupant_idx)

            undercut_price_diff = prices[bbx_occupant_idx] - prices[non_bbx_idx]

            # The occupant keeps the box unless it raised its price or was
            # undercut by at least ``adv``.
            if undercut_price_diff < self.adv and occupant_price_not_rise:
                return [bbx_occupant_idx]
            else:
                return [int(np.where(prices == prices_np[i])[0][0]) for i in range(bb_size)]

        elif self.dp_type == 'learn_exact_price':
            price_displayed = self.action_price_space[supervisor_action]
            return [i for i in range(self.num_agents) if prices[i] == price_displayed]

        elif self.dp_type == 'learn_threshold':
            price_thresh = self.action_price_space[supervisor_action]
            return [i for i in range(self.num_agents) if prices[i] <= price_thresh]

        elif self.dp_type == 'learn_general':
            # Strip brackets, commas and spaces so only the 0/1 digits remain.
            supervisor_action = re.sub(r'\W+', '', str(list(supervisor_action)))
            supervisor_action = [int(char_num) for char_num in list(str(supervisor_action))]
            return [i for i in range(self.num_agents) if supervisor_action[i] == 1]



class RLSupervisorQPricingWrapper(gym.Wrapper):
    """Turn the multi-agent pricing environment into a single-agent supervisor task.

    The pricing agents are tabular Q-learners owned by this wrapper (one dict
    per agent, keyed by the string form of the last price-index array). The
    external ``action`` is the supervisor's buy-box decision; the external
    reward is the consumer surplus reported by the wrapped environment.
    """


    def __init__(
            self,
            env,
            alpha=0.15,
            delta=0.95,
            beta=0.00001,
            bbox_state_space_type='price_profile',
            logger = None,
            log_freq=0
    ):
        """Create the Q-tables and the supervisor's observation/action spaces.

        Args:
            env: wrapped pricing environment.
            alpha: Q-learning step size.
            delta: discount factor.
            beta: decay rate of the exploration probability ``exp(-beta * t)``.
            bbox_state_space_type: 'price_profile', 'no_state' or 'binary'.
            logger: object with ``record``/``dump`` methods, used by ``log_info``.
            log_freq: log every ``log_freq`` steps; 0 disables logging.
        """

        super(RLSupervisorQPricingWrapper, self).__init__(env)

        self.bbox_state_space_type = bbox_state_space_type
        self.logger = logger
        self.log_freq = log_freq
        self.check_for_convergence = False

        self.q_table = [{} for _ in range(self.env.num_agents)]
        self.delta = delta
        self.beta = beta
        self.alpha = alpha
        self.q_init() # Initialize agents' q_table

        self.prev_loop_count = 0
        self.this_step_mode = "standard" # Use 'random' or 'argmax' to let pricing agents pick random or argmax prices without updating Q matrices

        self.steps_since_restart = [0] * self.env.num_agents
        self.step_counter = 0

        # Setup observation space
        obs_space_type = self.get_observation_space_type()
        self.observation_space = Dict({
            'base_environment': MultiDiscrete(obs_space_type),
        })

        # Setup action space
        if self.env.dp_type == 'learn_exact_price' or self.env.dp_type == 'learn_threshold':
            self.action_space = Discrete(self.env.m)
        elif self.env.dp_type == 'learn_general':
            self.action_space = MultiDiscrete([2]*self.env.num_agents) # Action here is agent set function
        elif self.env.dp_type == 'dpdp' or self.env.dp_type == 'pdp' or self.env.dp_type == 'no_intervene':
            self.action_space = Discrete(self.env.num_agents)


    def reset(self):
        """Reset the wrapped environment and pick the agents' first prices.

        Returns:
            OrderedDict with the single key ``'base_environment'``.
        """

        self.convergence_counter=0

        obs_sub_env = self.env.reset() # Restart sub_env
        init_prices = obs_sub_env['agent_0'] # Get prices quoted in the very first round

        self.current_obs_pricing_agents = init_prices
        self.current_action_pricing_agents = \
            self.get_pricing_agents_actions(str(self.current_obs_pricing_agents))

        observation = OrderedDict({"base_environment": self.convert_last_prices_to_obs()})

        return observation


    def step(self, action):
        """Apply the supervisor's action, play one pricing round and update the Q-tables.

        Args:
            action: supervisor action, forwarded to the wrapped environment.

        Returns:
            ``(observation, reward, False, info)``; ``reward`` is
            ``info['surplus']`` and ``done`` is always False here.
        """

        self.step_counter = self.step_counter+1

        self.env.set_regulator_action(action)

        for i in range(self.env.num_agents):
            self.steps_since_restart[i] = self.steps_since_restart[i] + 1

        # Create action_dict using prices and buy_box action at step t
        actions_dict = self.current_action_pricing_agents.copy()

        if self.env.dp_type == "learn_general":
            # Collapse e.g. [1, 0] to the digit string "10".
            actions_dict["supervisor"] = re.sub(r'\W+', '', str(list(action)))
        else:
            actions_dict["supervisor"] = action

        # Step in environment
        obs_sub_env, reward_pricing_agents, _, info = self.env.step(actions_dict)

        # New observation pricing agents includes prices at step t and it is needed to update q tables
        new_obs_pricing_agents = np.array([int(pr) for pr in obs_sub_env["agent_0"][0:self.env.num_agents]])
        new_obs_key = str(new_obs_pricing_agents)

        # Update agents' tables
        info["johnson_convergence"] = False
        if self.this_step_mode == 'standard':
            # Greedy actions in the new state *before* the update, to detect changes.
            old_choice_actions = [np.argmax(self.q_table[agent][new_obs_key]) for agent in range(self.env.num_agents)]
            self.update_q_tables(new_obs_key, actions_dict, reward_pricing_agents, str(self.current_obs_pricing_agents))
            if self.check_for_convergence:
                info["johnson_convergence"] = self.check_if_converged(new_obs_key, old_choice_actions)

        self.current_obs_pricing_agents = new_obs_pricing_agents
        self.current_action_pricing_agents = \
            self.get_pricing_agents_actions(str(self.current_obs_pricing_agents), action_type=self.this_step_mode) # Prices at step t+1

        # A non-standard mode only lasts for one step.
        self.this_step_mode = "standard"

        info["reward_pricing_agents"] = reward_pricing_agents
        reward = info["surplus"]
        observation = OrderedDict({"base_environment": self.convert_last_prices_to_obs()})

        if self.log_freq>0 and self.step_counter % self.log_freq == 0:
            self.log_info(action, info)

        return observation, reward, False, info


    def log_info(self, action, info):
        """Record the current step's statistics on ``self.logger`` and dump them."""

        self.logger.record("count_steps", self.step_counter)
        self.logger.record("supervisor_action", action)
        self.logger.record("consumer_surplus", info["surplus"])
        self.logger.record("agent0_reward", info["reward_pricing_agents"]["agent_0"])
        self.logger.record("agent1_reward", info["reward_pricing_agents"]["agent_1"])
        self.logger.record("c_i", round(self.env.c_i, 2))

        for j in range(self.env.num_agents):
            # Log the grid index of the price rather than the price itself.
            self.logger.record("price_"+str(j), np.where(self.env.action_price_space == self.env.prices[j])[0][0])

        self.logger.dump()


    def q_init(self):
        """Pre-populate every agent's Q-table for all two-agent price-index pairs."""

        for a_0 in range(self.env.m):
            for a_1 in range(self.env.m):
                observation = str(np.array([a_0, a_1]))

                for agent in range(self.env.num_agents):
                    if observation not in self.q_table[agent]:
                        self.q_table[agent][observation] = self.initialize_agents_q_table(observation)


    def q_matrices_to_norm_vec(self):
        """Return all Q-values as one flat vector, each agent's table scaled by its max |value|."""
        q_matrices = np.empty((0))
        for q_table in self.q_table:
            q_table_vec = np.array(list(q_table.values())).flatten()
            max_abs = max(abs(q_table_vec))
            # Skip the scaling when the table is all zeros.
            if max_abs>0: q_table_vec = q_table_vec / max_abs
            q_matrices = np.append(q_matrices, q_table_vec)
        return q_matrices


    def get_pricing_agents_actions(self, observation, action_type="standard"):
        """Choose the next price index of every pricing agent.

        Args:
            observation: Q-table key (string form of the last price indices).
            action_type: 'random' (uniform sample), 'argmax' (greedy) or
                'standard' (epsilon-greedy with ``epsilon = exp(-beta * steps_since_restart)``).

        Returns:
            dict mapping agent name to the chosen price index; empty for an
            unrecognised ``action_type``.
        """

        pricing_agents_actions = {}
        num_agents = self.env.num_agents

        if action_type == "random":
            for agent in range(num_agents):
                pricing_agents_actions[self.env.agents[agent]] = self.env.action_spaces[self.env.agents[agent]].sample()

        if action_type == "argmax":
            for agent in range(num_agents):
                pricing_agents_actions[self.env.agents[agent]] = np.argmax(self.q_table[agent][observation])

        if action_type == "standard":
            for agent in range(num_agents):
                epsilon = np.exp(-1 * self.beta * self.steps_since_restart[agent])
                if random.uniform(0, 1) < epsilon:
                    pricing_agents_actions[self.env.agents[agent]] = self.env.action_spaces[self.env.agents[agent]].sample()
                else:
                    pricing_agents_actions[self.env.agents[agent]] = np.argmax(self.q_table[agent][observation])

        return pricing_agents_actions


    def update_q_tables(self, next_observation, actions_dict, reward, prev_observation):
        """Apply one Q-learning update per agent.

        ``Q[prev][a] <- (1 - alpha) * Q[prev][a] + alpha * (r + delta * max Q[next])``

        Args:
            next_observation: Q-table key of the state reached.
            actions_dict: dict with each agent's action taken in ``prev_observation``.
            reward: dict with each agent's reward.
            prev_observation: Q-table key of the state the action was taken in.
        """
        last_values = [0] * self.env.num_agents
        Q_maxes = [0] * self.env.num_agents

        for agent in range(self.env.num_agents):
            if next_observation not in self.q_table[agent]:
                self.q_table[agent][next_observation] = self.initialize_agents_q_table(next_observation)

            last_values[agent] = self.q_table[agent][prev_observation][actions_dict[self.env.agents[agent]]]
            Q_maxes[agent] = np.max(self.q_table[agent][next_observation])

            self.q_table[agent][prev_observation][actions_dict[self.env.agents[agent]]] = \
                ((1 - self.alpha) * last_values[agent]) + (self.alpha * (reward[self.env.agents[agent]] + self.delta * Q_maxes[agent]))


    def check_if_converged(self, observation, old_choice_actions):
        """Track how long the greedy actions in ``observation`` have been unchanged.

        Returns:
            True exactly when the streak of unchanged greedy actions reaches 100000.
        """

        new_choice_actions = [np.argmax(self.q_table[agent][observation]) for agent in range(self.env.num_agents)]
        action_diff = np.sum(np.absolute(np.array(new_choice_actions) - np.array(old_choice_actions)))
        q_table_stable = action_diff == 0

        has_converged = False
        if q_table_stable:
            self.convergence_counter += 1
        else:
            self.convergence_counter = 0
        if self.convergence_counter == 100000:
            print("converged: True")
            has_converged = True
            print("convergence_counter:", self.convergence_counter)
            print("done: ", has_converged)
            print("-------------------------\n\n\n")

        return has_converged

    def new_initialize_agents_q_table(self, observation):
        """Build an initial Q-row using a price-index threshold predicted by ``self.model``.

        NOTE(review): ``self.model`` is not assigned in this class; presumably it
        is attached externally before this method is called -- confirm.

        Returns:
            list with one discounted expected profit per own price index,
            assuming the rival's price is uniform over the grid.
        """

        matrix = []

        threshhold = self.model.predict(observation, deterministic=True)[0]

        print("threshhold: ", threshhold)

        for i in range(self.env.m):
            reward = 0
            for j in range(self.env.m):

                price_i = self.env.action_price_space[i]
                price_j = self.env.action_price_space[j]

                prices = np.array([price_i, price_j])

                if i <= threshhold and j <= threshhold:
                    # both agents are displayed in the buy-box
                    bbx_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [0,1])
                    demand_i = bbx_demand

                elif i <= threshhold and j > threshhold:
                    #i is in the buy-box but j is not
                    bbx_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [0])
                    demand_i = bbx_demand

                elif i > threshhold and j <= threshhold:
                    #i is not the buy-box but j is
                    reg_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [1])
                    demand_i = reg_demand

                elif i > threshhold and j > threshhold:
                    #both i and j are outside the buybox
                    reg_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [])
                    demand_i = reg_demand

                profit_i = (price_i - self.env.c_i) * demand_i

                reward += profit_i / self.env.m  # assume all actions by the other agent are uniformly possible

            matrix.append(float(reward / (1 - self.delta)))

        return matrix


    def initialize_agents_q_table(self, observation):
        """Build an initial Q-row assuming a lowest-price-wins buy-box.

        Args:
            observation: Q-table key; not used in the computation.

        Returns:
            list with one discounted expected profit per own price index,
            assuming the rival's price is uniform over the grid.
        """

        matrix = []

        for i in range(self.env.m):
            reward = 0
            for j in range(self.env.m):

                price_i = self.env.action_price_space[i]
                price_j = self.env.action_price_space[j]
                prices = np.array([price_i, price_j])

                bbx_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [0])
                reg_demand = self.q_demand(self.env.a, prices, self.env.mu, 0, [1])

                if price_i < price_j:
                    # agent_i wins the buy_box
                    demand_i = bbx_demand

                elif price_i > price_j:
                    # agent_i loses the buy_box
                    demand_i = reg_demand

                elif price_i == price_j:
                    # agent_i has a probability bbx_probab of being in buybox according to equation 3 on Johnson paper
                    bbx_probab = self.smooth_ties(price_i, prices)
                    demand_i = bbx_probab * bbx_demand + (1 - bbx_probab) * reg_demand

                profit_i = (price_i - self.env.c_i) * demand_i

                reward += profit_i / self.env.m  # assume all actions by the other agent are uniformly possible

            matrix.append(float(reward / (1 - self.delta)))

        return matrix


    def q_demand(self, a, p, mu, agent_idx, bb_idx):
        ''' Demand as a function of product quality indexes, price, and mu.

        Same formula as the wrapped environment's ``demand``, except that
        ``bb_idx`` must be a collection (an empty one means nobody is displayed).
        '''
        outside_good = np.exp(self.env.a_0 / mu)

        normal_demand = np.exp((a[agent_idx] - p[agent_idx]) / mu) / (
                    np.sum(np.exp((a - p) / mu)) + outside_good)

        if agent_idx not in bb_idx:
            bbx_demand = 0

        else:
            denom = np.sum([np.exp((a[idx] - p[idx]) / mu) for idx in bb_idx]) + outside_good
            bbx_demand = np.exp((a[agent_idx] - p[agent_idx]) / mu) / denom

        demand = self.env.gamma * bbx_demand + (1.0 - self.env.gamma) * normal_demand

        return demand


    def smooth_ties(self, price_i, prices):
        """Return the soft-min weight of ``price_i`` among ``prices`` (sigma = 0.01).

        Computes ``exp(-price_i / sigma) / sum_j exp(-prices[j] / sigma)``.
        """
        sigma = 0.01
        # Shift by the lowest price before exponentiating: the ratio is
        # unchanged, but the largest term becomes exp(0) so the denominator can
        # never underflow to zero (which would yield NaN).
        lowest = min(min(prices), price_i)
        numer = np.exp(-(price_i - lowest) / sigma)
        den = np.sum([np.exp(-(pi - lowest) / sigma) for pi in prices])
        return numer / den


    def get_observation_space_type(self):
        """Return the ``MultiDiscrete`` dimensions for the configured state-space type.

        Returns None for an unrecognised ``bbox_state_space_type``.
        """
        if self.bbox_state_space_type == "price_profile": return [self.env.m, self.env.m]
        if self.bbox_state_space_type == "no_state": return [1]
        if self.bbox_state_space_type == "binary": return [2]


    def convert_last_prices_to_obs(self):
        """Return the supervisor observation built from the agents' pending price indices."""
        price_array = np.array([self.current_action_pricing_agents[self.env.agents[i]] for i in range(self.env.num_agents)])
        return self.adapt_price_array(price_array)


    def adapt_price_array(self, price_array):
        """Map an array of price indices to the configured observation format.

        Returns the array itself ('price_profile'), ``[0]`` ('no_state'), or a
        one-element flag that is 1 when the highest index exceeds 3 ('binary').
        Returns None for an unrecognised ``bbox_state_space_type``.
        """
        if self.bbox_state_space_type == "price_profile":
            return price_array
        if self.bbox_state_space_type == "no_state":
            return np.array([0])
        if self.bbox_state_space_type == "binary":
            collusion_flag = 1 if np.max(price_array)>3 else 0
            return np.array([collusion_flag])



class StackMDPWrapper(gym.Wrapper):
    """Split every episode into unrewarded "equilibrium" steps followed by "reward" steps.

    The first ``tot_num_eq_steps`` steps of an episode return reward 0; the next
    ``tot_num_reward_steps`` steps return the wrapped environment's reward, and
    the last of them ends the episode. The observation can be extended with
    critic-only entries selected by ``critic_obs``.
    """

    def __init__(
            self,
            env,
            tot_num_eq_steps=1000,
            tot_num_reward_steps=10,
            frac_excluded_eq_steps=0,
            reward_step_random_price_prob=0,
            critic_obs="flag"
    ):
        """Store the episode layout and build the extended observation space.

        Args:
            env: wrapped environment exposing a Dict observation space with a
                ``'base_environment'`` entry.
            tot_num_eq_steps: number of equilibrium steps per episode.
            tot_num_reward_steps: number of reward steps per episode.
            frac_excluded_eq_steps: fraction of equilibrium steps flagged with
                ``info['exclude_from_buffer'] = True``.
            reward_step_random_price_prob: probability that a reward step makes
                the pricing agents act randomly.
            critic_obs: 'flag' adds ``critic:is_reward_step``; 'Q_matrix'
                additionally adds the Q-values and exploration rates.
        """

        super(StackMDPWrapper, self).__init__(env)

        # This sets the total number of equilibrium and reward steps in StackMDP
        self.tot_num_eq_steps = tot_num_eq_steps
        self.tot_num_reward_steps = tot_num_reward_steps
        self.do_log = True

        self.tot_num_steps = 0
        self.critic_obs = critic_obs

        self.frac_excluded_eq_steps = frac_excluded_eq_steps
        self.reward_step_random_price_prob = reward_step_random_price_prob

        # Here we need to enhance the observation space for Stackelberg MDP

        if self.critic_obs == "flag":
            self.observation_space = Dict({
                'base_environment': self.env.observation_space['base_environment'],
                'critic:is_reward_step': Discrete(2),
            })
        if self.critic_obs == "Q_matrix":
            num_q_entries = 0
            for q_table in self.q_table:
                num_q_entries = num_q_entries + len(np.array(list(q_table.values())).flatten())
            self.observation_space = Dict({
                'critic:is_reward_step': Discrete(2),
                'base_environment': self.env.observation_space['base_environment'],
                'critic:exploration_rates': Box(low=0, high=1.0, shape=(self.env.num_agents,)),
                'critic:Q_matrices': Box(low=-1.0, high=1.0, shape=(num_q_entries,)),
            })

        print("------------------------------------------")
        print("StackMDP ENV PARAMETERS")
        print("tot_num_eq_steps: ", tot_num_eq_steps)
        print("tot_num_reward_steps: ", tot_num_reward_steps)

        print("------------------------------------------")


    def _build_observation(self, base_obs, is_reward_step):
        """Wrap ``base_obs`` with the critic entries selected by ``critic_obs``."""
        full_observation = OrderedDict({"base_environment": base_obs})
        if self.critic_obs == "flag" or self.critic_obs == "Q_matrix":
            full_observation["critic:is_reward_step"] = is_reward_step
        if self.critic_obs == "Q_matrix":
            full_observation["critic:Q_matrices"] = self.q_matrices_to_norm_vec()
            full_observation['critic:exploration_rates'] = np.array(
                [np.exp(-1 * self.env.beta * self.env.steps_since_restart[agent])
                 for agent in range(self.env.num_agents)])
        return full_observation


    def reset(self):
        """Reset the wrapped environment and the per-episode step counters.

        Returns:
            The extended observation with ``critic:is_reward_step`` set to 0
            (when that entry is enabled).
        """

        obs_sub_env = self.env.reset() # Restart sub_env

        self.eq_steps_counter = 0
        self.reward_steps_counter = 0

        # The following lines are needed to exclude steps from replay buffers:
        # draw the 1-based equilibrium-step numbers to flag, in ascending order.
        number_of_excluded_steps = int(self.frac_excluded_eq_steps * self.tot_num_eq_steps)
        self.excluded_indexes = random.sample(range(1, self.tot_num_eq_steps+1), k=number_of_excluded_steps)
        self.excluded_indexes.sort()

        return self._build_observation(obs_sub_env['base_environment'], 0)


    def set_reward_step_mode(self):
        """Choose how the pricing agents act on the next step of the wrapped environment."""
        self.env.this_step_mode = "random" if random.uniform(0, 1) < self.reward_step_random_price_prob else "standard"
        # Evaluation environments always use greedy pricing.
        if self.is_evaluation_env: self.env.this_step_mode = "argmax"


    def step(self, action):
        """Run one equilibrium or reward step, depending on the episode position.

        Returns:
            ``(observation, reward, done, info)``. Equilibrium steps return
            reward 0 and ``done=False``; reward steps return the wrapped
            reward and ``done=True`` on the last one.

        Raises:
            RuntimeError: if called after the episode's last reward step
                without an intervening ``reset``.
        """

        self.tot_num_steps += 1

        if self.tot_num_steps % 100000 == 0:
            print("StackMDPWrapper steps completed: ", self.tot_num_steps)

        if self.eq_steps_counter < self.tot_num_eq_steps:

            # We do an equilibrium step
            self.eq_steps_counter+=1

            # The mode chosen during a step governs the prices the agents pick
            # for the *next* step, so the first reward step is prepared here.
            if self.eq_steps_counter == self.tot_num_eq_steps: self.set_reward_step_mode()

            obs, _, _, info = self.env.step(action)

            full_observation = self._build_observation(obs['base_environment'], 0)

            info["exclude_from_buffer"] = False

            # Checks if step should be excluded from buffer
            if len(self.excluded_indexes) > 0 and self.eq_steps_counter == self.excluded_indexes[0]:
                self.excluded_indexes.pop(0)
                info["exclude_from_buffer"] = True

            return full_observation, 0, False, info

        if self.reward_steps_counter < self.tot_num_reward_steps:

            # We do a reward step
            self.reward_steps_counter+=1

            self.set_reward_step_mode()
            obs, reward, _, info = self.env.step(action)

            info["exclude_from_buffer"] = False

            full_observation = self._build_observation(obs['base_environment'], 1)
            done = self.reward_steps_counter == self.tot_num_reward_steps

            if self.do_log:
                self.log_info(action, info)

            return full_observation, reward, done, info

        # Both budgets are used up: there is nothing meaningful to return.
        raise RuntimeError(
            "StackMDPWrapper.step() called after the episode finished; call reset() first")


    def log_info(self, action, info):
        """Record the statistics of a reward step on ``self.logger`` and dump them.

        NOTE(review): ``self.logger`` and ``self.model`` are not assigned in this
        class; presumably they are provided by an inner wrapper or attached
        externally -- confirm.
        """

        self.logger.record("frac_displayed_agents", self.compute_frac_displayed_agents())
        self.logger.record("count_steps", self.tot_num_steps)
        self.logger.record("supervisor_action", action)
        # Key spelling matches the one used by the inner wrapper's log_info.
        self.logger.record("consumer_surplus", info["surplus"])
        self.logger.record("agent0_reward", info["reward_pricing_agents"]["agent_0"])
        self.logger.record("agent1_reward", info["reward_pricing_agents"]["agent_1"])
        self.logger.record("c_i", round(self.env.c_i, 2))

        # Grid index of each agent's current price.
        prices = [np.where(self.env.action_price_space == self.env.prices[j])[0][0] for j in range(self.env.num_agents)]
        for j in range(self.env.num_agents):
            self.logger.record("price_"+str(j), prices[j])

        obs_sub_env = self.adapt_price_array(np.array(prices))
        full_observation = self._build_observation(obs_sub_env, 0)

        arg_max_action = self.model.predict(full_observation, deterministic=True)[0]
        self.logger.record("arg_max_action", arg_max_action)

        self.logger.dump()

    def compute_frac_displayed_agents(self):
        """Return the average share of agents the model's greedy action displays.

        Iterates over every pair of price indices, asks ``self.model`` for its
        deterministic action and counts the agents ``get_bbx_idx`` returns.
        """
        num_displayed_agents = 0
        for i in range(self.env.m):
            for j in range(self.env.m):
                price_i = self.env.action_price_space[i]
                price_j = self.env.action_price_space[j]
                prices = np.array([price_i, price_j])
                obs_sub_env = self.adapt_price_array(np.array([i, j]))
                full_observation = self._build_observation(obs_sub_env, 0)
                temp_action = self.model.predict(full_observation, deterministic=True)[0]
                num_displayed_agents = num_displayed_agents + len(self.env.get_bbx_idx(prices, temp_action))

        frac_displayed_agents = num_displayed_agents / (math.pow(self.env.m, self.env.num_agents)*self.env.num_agents)
        return frac_displayed_agents



class RestartExplorationRateWrapper(gym.Wrapper):
    """Wrapper that restarts the exploration schedule of the pricing agents.

    A restart zeroes ``steps_since_restart`` (for all agents or for a single
    agent) on every ``RLSupervisorQPricingWrapper`` returned by
    ``get_all_wrappers(self.env)``.

    Args:
        env: wrapped environment (originally documented as a StackMDPWrapper);
            must expose ``num_agents``, ``tot_num_eq_steps`` and
            ``tot_num_reward_steps``.
        restart_per_episode_rate (float): ``-1`` restarts every agent on each
            ``reset()``. A positive value restarts each agent independently at
            every ``step()`` with probability
            ``restart_per_episode_rate / (tot_num_eq_steps + tot_num_reward_steps)``,
            i.e. in expectation that many restarts per agent per episode.
            Any other value disables restarts.
    """

    def __init__(
            self,
            env,
            restart_per_episode_rate=-1,
    ):
        super(RestartExplorationRateWrapper, self).__init__(env)
        self.restart_per_episode_rate = restart_per_episode_rate

    def reset(self):
        """Reset the wrapped env, restarting all agents first when the rate is -1."""
        restart_everyone = self.restart_per_episode_rate == -1
        if restart_everyone:
            for layer in get_all_wrappers(self.env):
                if type(layer) is RLSupervisorQPricingWrapper:
                    # Each matching wrapper receives its own fresh list.
                    layer.steps_since_restart = [0] * self.env.num_agents

        return self.env.reset()

    def step(self, action):
        """Randomly restart individual agents (positive rate only), then step the wrapped env."""
        if self.restart_per_episode_rate > 0:
            # One uniform draw per agent, in agent order.
            chosen_agents = [
                agent
                for agent in range(self.env.num_agents)
                if random.uniform(0, 1) < self.restart_per_episode_rate / (
                    self.env.tot_num_eq_steps + self.env.tot_num_reward_steps)
            ]
            if chosen_agents:
                for layer in get_all_wrappers(self.env):
                    if type(layer) is RLSupervisorQPricingWrapper:
                        for agent in chosen_agents:
                            layer.steps_since_restart[agent] = 0

        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info


class RestartExplorationRateWrapperNonEpisodic(gym.Wrapper):
    """Wrapper that restarts exploration on a step-count basis instead of per episode.

    A restart zeroes ``steps_since_restart`` (for all agents or for a single
    agent) on every ``RLSupervisorQPricingWrapper`` returned by
    ``get_all_wrappers(self.env)``. All agents are always restarted on
    ``reset()``, which also zeroes the internal step counter.

    Args:
        env: wrapped environment (originally documented as a StackMDPWrapper);
            must expose ``num_agents``.
        restart_per_episode_rate (float): a positive value restarts each agent
            independently at every ``step()`` with probability
            ``restart_per_episode_rate / expected_num_steps_between_restarts``.
            ``-1`` restarts all agents whenever the number of steps since the
            last ``reset()`` is a multiple of
            ``expected_num_steps_between_restarts``.
        expected_num_steps_between_restarts (int): denominator of the per-step
            restart probability, or the restart period when the rate is ``-1``.
    """

    def __init__(
            self,
            env,
            restart_per_episode_rate=-1,
            expected_num_steps_between_restarts=5030,
    ):
        super(RestartExplorationRateWrapperNonEpisodic, self).__init__(env)
        self.expected_num_steps_between_restarts = expected_num_steps_between_restarts
        self.restart_per_episode_rate = restart_per_episode_rate

    def reset(self):
        """Zero the step counter, restart every agent and reset the wrapped env."""
        self.count_steps = 0

        for layer in get_all_wrappers(self.env):
            if type(layer) is RLSupervisorQPricingWrapper:
                layer.steps_since_restart = [0] * self.env.num_agents

        return self.env.reset()

    def step(self, action):
        """Apply the configured restart rule, then step the wrapped env."""
        self.count_steps += 1

        if self.restart_per_episode_rate > 0:
            # Random mode: one uniform draw per agent, in agent order.
            for agent in range(self.env.num_agents):
                draw = random.uniform(0, 1)
                if draw < self.restart_per_episode_rate / (self.expected_num_steps_between_restarts):
                    for layer in get_all_wrappers(self.env):
                        if type(layer) is RLSupervisorQPricingWrapper:
                            layer.steps_since_restart[agent] = 0
        elif (self.restart_per_episode_rate == -1
              and self.count_steps % self.expected_num_steps_between_restarts == 0):
            # Periodic mode: everybody restarts every N-th step.
            for layer in get_all_wrappers(self.env):
                if type(layer) is RLSupervisorQPricingWrapper:
                    layer.steps_since_restart = [0] * self.env.num_agents

        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info


class EvaluationAfterJohnsonConvergence(gym.Wrapper):
    """Wrapper that logs a fixed number of steps once the wrapped env has converged.

    The episode runs until ``info["johnson_convergence"]`` is truthy or
    ``max_Q_steps`` steps have been taken. From then on every
    ``RLSupervisorQPricingWrapper`` in the stack gets ``this_step_mode = "argmax"``
    after each step, ``total_log_steps`` (30) steps are logged, and the episode
    is then flagged as done.

    Args:
        env: wrapped environment.
        max_Q_steps (int): step count after which logging starts even without
            convergence; also written to ``tot_num_eq_steps`` of any
            ``StackMDPWrapper`` in the stack on ``reset()``.
    """

    def __init__(
            self,
            env,
            max_Q_steps=50000000,
    ):
        super(EvaluationAfterJohnsonConvergence, self).__init__(env)
        self.max_Q_steps = max_Q_steps
        self.total_log_steps = 30

        # Turn on the convergence check in every Q-pricing supervisor wrapper.
        for layer in get_all_wrappers(self.env):
            if type(layer) is RLSupervisorQPricingWrapper:
                layer.check_for_convergence = True

    def reset(self):
        """Clear the counters, reconfigure any StackMDPWrapper and reset the wrapped env."""
        self.steps_counter = 0
        self.log_steps_counter = 0
        self.convergence_flag = False

        # When a StackMDPWrapper is in the stack, keep it in its equilibrium
        # phase and switch off its own logging.
        for layer in get_all_wrappers(self.env):
            if type(layer) is StackMDPWrapper:
                layer.tot_num_eq_steps = self.max_Q_steps
                layer.do_log = False

        return self.env.reset()

    def step(self, action):
        """Step the wrapped env; after convergence, force argmax mode, log, and end the episode."""
        self.steps_counter += 1
        obs, reward, done, info = self.env.step(action)

        if info["johnson_convergence"]:
            self.convergence_flag = True

        logging_phase = self.convergence_flag or self.steps_counter >= self.max_Q_steps
        if not logging_phase:
            return obs, reward, done, info

        self.log_steps_counter += 1

        for layer in get_all_wrappers(self.env):
            if type(layer) is RLSupervisorQPricingWrapper:
                layer.this_step_mode = "argmax"

        # The very first step of the logging phase is not logged because its
        # prices were not chosen via argmax.
        if self.log_steps_counter > 1:
            self.log_info(info)

        if self.log_steps_counter == self.total_log_steps + 1:
            done = True

        return obs, reward, done, info

    def log_info(self, info):
        """Record surplus, both agents' rewards, the marginal cost and price indices, then dump."""
        self.logger.record("consumer_surplus", info["surplus"])

        agent_rewards = info["reward_pricing_agents"]
        self.logger.record("agent0_reward", agent_rewards["agent_0"])
        self.logger.record("agent1_reward", agent_rewards["agent_1"])

        self.logger.record("c_i", round(self.c_i, 2))

        for j in range(self.env.num_agents):
            # Index of agent j's current price within the discrete price grid.
            price_index = np.where(self.env.action_price_space == self.env.prices[j])[0][0]
            self.logger.record("price_" + str(j), price_index)

        self.logger.dump()